{
  "cells": [
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "-4DFtQNDYao1"
      },
      "source": [
        "# Intro to TorchRec\n",
        "\n",
        "Frequently, when building recommendation systems, we want to represent entities like products or pages with embeddings. For example, see Meta AI's [Deep learning recommendation model](https://arxiv.org/abs/1906.00091), or DLRM. As the number of entities grow, the size of the embedding tables can exceed a single GPU’s memory. A common practice is to shard the embedding table across devices, a type of model parallelism. To that end, **TorchRec introduces its primary API called [`DistributedModelParallel`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel), or DMP. Like pytorch’s DistributedDataParallel, DMP wraps a model to enable distributed training.**"
      ]
    },
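    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To make the memory argument above concrete, here is a rough back-of-the-envelope sketch in plain Python. The function name `embedding_table_bytes` and the example sizes are illustrative assumptions, not TorchRec APIs; the sketch assumes dense float32 weights and ignores optimizer state.\n",
        "\n",
        "```python\n",
        "def embedding_table_bytes(num_embeddings, embedding_dim, bytes_per_element=4):\n",
        "    '''Approximate size of a dense embedding table (float32 by default).'''\n",
        "    return num_embeddings * embedding_dim * bytes_per_element\n",
        "\n",
        "\n",
        "# Hypothetical sizes, for illustration only.\n",
        "small = embedding_table_bytes(4096, 64)            # same shape as the tables in this tutorial\n",
        "large = embedding_table_bytes(1_000_000_000, 128)  # a billion-row table\n",
        "\n",
        "print(f'{small / 2**20:.1f} MiB')\n",
        "print(f'{large / 2**30:.1f} GiB')\n",
        "```\n",
        "\n",
        "The small table fits anywhere, while the billion-row table needs several hundred GiB, which is the situation where sharding across devices becomes necessary."
      ]
    },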
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "hBgIy9eYYx35"
      },
      "source": [
        "## **Installation**\n",
        "Requirements:\n",
        "- python >= 3.9\n",
        "\n",
        "We highly recommend CUDA when using TorchRec. If using CUDA:\n",
        "- cuda >= 11.8\n",
        "\n",
        "Installing TorchRec will also install [FBGEMM](https://github.com/pytorch/fbgemm), a collection of CUDA kernels and GPU enabled operations to run "
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {
        "id": "sFYvP95xaAER"
      },
      "outputs": [],
      "source": [
        "!pip3 install --pre torch --index-url https://download.pytorch.org/whl/nightly/cu121 -U\n",
        "!pip3 install fbgemm_gpu --index-url https://download.pytorch.org/whl/nightly/cu121\n",
        "!pip3 install torchmetrics==1.0.3\n",
        "!pip3 install torchrec --index-url https://download.pytorch.org/whl/nightly/cu121"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "HWBOrwVSnrNE"
      },
      "source": [
        "## **Overview**\n",
        "This tutorial will cover three pieces of TorchRec - the `nn.module` `EmbeddingBagCollection`, the `DistributedModelParallel` API, and the datastructure `KeyedJaggedTensor`."
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "udsN6PlUo1zF"
      },
      "source": [
        "### Distributed Setup\n",
        "We setup our environment with torch.distributed. For more info on distributed, see this [tutorial](https://pytorch.org/tutorials/beginner/dist_overview.html)\n",
        "\n",
        "Here, we use one rank (the colab process) corresponding to our 1 colab GPU."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 1,
      "metadata": {
        "id": "4-v17rxkopQw"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "/home/ylgh/anaconda3/envs/torchrec/lib/python3.7/site-packages/tqdm/auto.py:22: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
            "  from .autonotebook import tqdm as notebook_tqdm\n"
          ]
        }
      ],
      "source": [
        "import os\n",
        "import torch\n",
        "import torchrec\n",
        "import torch.distributed as dist\n",
        "\n",
        "os.environ[\"RANK\"] = \"0\"\n",
        "os.environ[\"WORLD_SIZE\"] = \"1\"\n",
        "os.environ[\"MASTER_ADDR\"] = \"localhost\"\n",
        "os.environ[\"MASTER_PORT\"] = \"29500\"\n",
        "\n",
        "# Note - you will need a V100 or A100 to run tutorial as as!\n",
        "# If using an older GPU (such as colab free K80), \n",
        "# you will need to compile fbgemm with the appripriate CUDA architecture\n",
        "# or run with \"gloo\" on CPUs \n",
        "dist.init_process_group(backend=\"nccl\")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZdSUWBRxoP8R"
      },
      "source": [
        "### From EmbeddingBag to EmbeddingBagCollection\n",
        "Pytorch represents embeddings through [`torch.nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) and [`torch.nn.EmbeddingBag`](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html). EmbeddingBag is a pooled version of Embedding.\n",
        "\n",
        "TorchRec extends these modules by creating collections of embeddings. We will use [`EmbeddingBagCollection`](https://pytorch.org/torchrec/torchrec.modules.html#torchrec.modules.embedding_modules.EmbeddingBagCollection) to represent a group of EmbeddingBags.\n",
        "\n",
        "Here, we create an EmbeddingBagCollection (EBC) with two embedding bags. Each table, `product_table` and `user_table`, is represented by 64 dimension embedding of size 4096. Note how we initially allocate the EBC on device \"meta\". This will tell EBC to not allocate memory yet."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 2,
      "metadata": {
        "id": "Iz_GZDp_oQ19"
      },
      "outputs": [],
      "source": [
        "ebc = torchrec.EmbeddingBagCollection(\n",
        "    device=\"meta\",\n",
        "    tables=[\n",
        "        torchrec.EmbeddingBagConfig(\n",
        "            name=\"product_table\",\n",
        "            embedding_dim=64,\n",
        "            num_embeddings=4096,\n",
        "            feature_names=[\"product\"],\n",
        "            pooling=torchrec.PoolingType.SUM,\n",
        "        ),\n",
        "        torchrec.EmbeddingBagConfig(\n",
        "            name=\"user_table\",\n",
        "            embedding_dim=64,\n",
        "            num_embeddings=4096,\n",
        "            feature_names=[\"user\"],\n",
        "            pooling=torchrec.PoolingType.SUM,\n",
        "        )\n",
        "    ]\n",
        ")"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "# FBGEMM Optimizations - Batching and Fusion\n",
        "\n",
        "TorchRec provides abstractions over [FBGEMM](https://github.com/pytorch/FBGEMM/tree/main/fbgemm_gpu) kernels that provide efficient implementations of the canonical nn.EmbeddingBags. Two of the optimizations that can be done are \n",
        "\n",
        "* Table batching, which allows you to look up multiple embeddings with one kernel call.\n",
        "* Optimizer Fusion, which allows the module to update itself given the canonical pytorch optimizers and arguments.\n",
        "\n",
        "This can be accessed by using the [fuse_embedding_optimizer](https://github.com/pytorch/torchrec/blob/main/torchrec/modules/fused_embedding_modules.py#L271) wrapper, which will replace embedding modules with their batched and fused counter parts. You can also directly use these efficient counterparts, take a look at torchrec.modules.fused_embedding_modules.\n",
        "\n",
        "To quantitatively see the performance gains of this, see our [benchmarks](https://github.com/pytorch/torchrec/blob/main/benchmarks/README.md).\n",
        "\n",
        "Note that this step is optional - the following steps can also be applied to the non-optimizer EmbeddingBagCollection."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 18,
      "metadata": {},
      "outputs": [],
      "source": [
        "from torchrec.optim.apply_optimizer_in_backward import apply_optimizer_in_backward\n",
        "\n",
        "apply_optimizer_in_backward(\n",
        "    optimizer_class=torch.optim.SGD,\n",
        "    params=ebc.parameters(),\n",
        "    optimizer_kwargs={\"lr\": 0.02},\n",
        ")"
      ]
    },
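    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "As a conceptual illustration of optimizer fusion, the toy sketch below applies an SGD step directly to the embedding rows touched in a batch, rather than building a gradient for the whole table and stepping afterwards. It is plain Python, not the FBGEMM kernel; `sgd_update_touched_rows` and the tiny table are hypothetical.\n",
        "\n",
        "```python\n",
        "def sgd_update_touched_rows(table, ids, row_grads, lr):\n",
        "    '''Apply w <- w - lr * g in place, only to the rows that were looked up.'''\n",
        "    for row_id, grad in zip(ids, row_grads):\n",
        "        table[row_id] = [w - lr * g for w, g in zip(table[row_id], grad)]\n",
        "\n",
        "\n",
        "# Hypothetical 3-row table with 2-dimensional embeddings.\n",
        "table = [[1.0, 1.0], [2.0, 2.0], [3.0, 3.0]]\n",
        "\n",
        "# Only row 1 was looked up in this batch, so only row 1 changes.\n",
        "sgd_update_touched_rows(table, ids=[1], row_grads=[[0.5, -0.5]], lr=0.02)\n",
        "print(table)\n",
        "```\n",
        "\n",
        "Rows 0 and 2 are left untouched, which is why this style of update suits large, sparsely accessed embedding tables."
      ]
    },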
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "7m0_ssVLFQEH"
      },
      "source": [
        "### DistributedModelParallel\n",
        "Now, we’re ready to wrap our model with [`DistributedModelParallel`](https://pytorch.org/torchrec/torchrec.distributed.html#torchrec.distributed.model_parallel.DistributedModelParallel) (DMP). Instantiating DMP will:\n",
        "\n",
        "1. Decide how to shard the model. DMP will collect the available ‘sharders’ and come up with a ‘plan’ of the optimal way to shard the embedding table(s) (i.e, the EmbeddingBagCollection)\n",
        "2. Actually shard the model. This includes allocating memory for each embedding table on the appropriate device(s).\n",
        "\n",
        "In this toy example, since we have two EmbeddingTables and one GPU, TorchRec will place both on the single GPU.\n",
        "\n",
        "To learn more about sharding, see our [sharding tutorial](https://pytorch.org/tutorials/advanced/sharding.html).\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 19,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "arW0Jf6qEl-h",
        "outputId": "66c515f1-b432-4b8f-abca-40f346942fe4"
      },
      "outputs": [
        {
          "name": "stderr",
          "output_type": "stream",
          "text": [
            "WARNING:root:Could not determine LOCAL_WORLD_SIZE from environment, falling back to WORLD_SIZE.\n"
          ]
        },
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "DistributedModelParallel(\n",
            "  (_dmp_wrapped_module): ShardedFusedEmbeddingBagCollection(\n",
            "    (_input_dists): ModuleList()\n",
            "    (_lookups): ModuleList(\n",
            "      (0): GroupedPooledEmbeddingsLookup(\n",
            "        (_emb_modules): ModuleList(\n",
            "          (0): BatchedFusedEmbeddingBag(\n",
            "            (_emb_module): SplitTableBatchedEmbeddingBagsCodegen()\n",
            "          )\n",
            "        )\n",
            "        (_score_emb_modules): ModuleList()\n",
            "      )\n",
            "    )\n",
            "    (_output_dists): ModuleList(\n",
            "      (0): TwPooledEmbeddingDist(\n",
            "        (_dist): PooledEmbeddingsAllToAll()\n",
            "      )\n",
            "    )\n",
            "  )\n",
            ")\n",
            "{'': {'product_table': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)])), 'user_table': ParameterSharding(sharding_type='table_wise', compute_kernel='batched_fused', ranks=[0], sharding_spec=EnumerableShardingSpec(shards=[ShardMetadata(shard_offsets=[0, 0], shard_sizes=[4096, 64], placement=rank:0/cuda:0)]))}}\n"
          ]
        }
      ],
      "source": [
        "model = torchrec.distributed.DistributedModelParallel(ebc, device=torch.device(\"cuda\"))\n",
        "print(model)\n",
        "print(model.plan)"
      ]
    },
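    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To build intuition for what a table-wise plan means, the toy sketch below assigns each whole table to the rank holding the least data so far. The greedy rule and the name `place_tables` are assumptions made for illustration; TorchRec's actual planner is not reproduced here.\n",
        "\n",
        "```python\n",
        "def place_tables(table_sizes, world_size):\n",
        "    '''Greedy table-wise placement: largest tables first, least-loaded rank wins.'''\n",
        "    load = [0] * world_size\n",
        "    placement = {}\n",
        "    for name, size in sorted(table_sizes.items(), key=lambda kv: -kv[1]):\n",
        "        rank = load.index(min(load))\n",
        "        placement[name] = rank\n",
        "        load[rank] += size\n",
        "    return placement\n",
        "\n",
        "\n",
        "# Sizes in number of weights (rows * embedding_dim), as in this tutorial.\n",
        "tables = {'product_table': 4096 * 64, 'user_table': 4096 * 64}\n",
        "\n",
        "print(place_tables(tables, world_size=1))  # both tables land on rank 0\n",
        "print(place_tables(tables, world_size=2))  # one table per rank\n",
        "```\n",
        "\n",
        "With a single rank, both tables end up on rank 0, which matches the `table_wise` plan printed above."
      ]
    },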
    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "## Sharders and Quantized Comms\n",
        "\n",
        "By default, DistributedModelParallel will identify which Sharder to use for your embedding module. In the above case, it creates a default EmbeddingBagCollectionSharder.\n",
        "\n",
        "However, you may also specify your own Sharder; by doing so, you can set additional sharding parameters. For example, you can specify the quantized/mixed precision config. \n",
        "\n",
        "Applying quantization and mixed precision to collective calls as part of distributed training is often used to increase the model's training throughput, while at the same time, not significantly sacrificing model quality. \n",
        "\n",
        "TorchRec provides helper functions to construct communication codecs, based on the [FBGEMM Qcomm library](https://github.com/pytorch/FBGEMM/blob/main/fbgemm_gpu/fbgemm_gpu/quantize_comm.py).\n",
        "\n",
        "Below, we create a sharder that uses FP16 mixed precision for the forward pass (when passing embedding tensors around in a collective call, first cast them to FP16). And similarly, for the backwards pass, cast tensors within collective calls to BF16."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": null,
      "metadata": {},
      "outputs": [
        {
          "ename": "",
          "evalue": "",
          "output_type": "error",
          "traceback": [
            "\u001b[1;31mRunning cells with 'Python 3.7.13 ('torchrec': conda)' requires ipykernel package.\n",
            "\u001b[1;31mRun the following command to install 'ipykernel' into the Python environment. \n",
            "\u001b[1;31mCommand: 'conda install -n torchrec ipykernel --update-deps --force-reinstall'"
          ]
        }
      ],
      "source": [
        "from torchrec.distributed.fbgemm_qcomm_codec import get_qcomm_codecs_registry, QCommsConfig, CommType\n",
        "from torchrec.distributed.embeddingbag import EmbeddingBagCollectionSharder\n",
        "\n",
        "sharder = EmbeddingBagCollectionSharder(\n",
        "    qcomm_codecs_registry=get_qcomm_codecs_registry(\n",
        "            qcomms_config=QCommsConfig(\n",
        "                forward_precision=CommType.FP16,\n",
        "                backward_precision=CommType.BF16,\n",
        "            )\n",
        "        )\n",
        ")\n",
        "model = torchrec.distributed.DistributedModelParallel(ebc, sharders=[sharder], device=torch.device(\"cuda\"))"
      ]
    },
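    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "To see the precision trade-off between the two formats named above, the sketch below round-trips a Python float through FP16 and through a truncated BF16 using only the standard library. It does not use TorchRec's codecs; `to_fp16` and `to_bf16_truncated` are hypothetical helpers, and the BF16 helper truncates rather than rounds, which real implementations may handle differently.\n",
        "\n",
        "```python\n",
        "import struct\n",
        "\n",
        "\n",
        "def to_fp16(x):\n",
        "    '''Round-trip through IEEE half precision (5 exponent bits, 10 mantissa bits).'''\n",
        "    return struct.unpack('<e', struct.pack('<e', x))[0]\n",
        "\n",
        "\n",
        "def to_bf16_truncated(x):\n",
        "    '''Keep the top 16 bits of the float32 encoding (8 exponent bits, 7 mantissa bits).'''\n",
        "    bits = struct.unpack('<I', struct.pack('<f', x))[0]\n",
        "    return struct.unpack('<f', struct.pack('<I', bits & 0xFFFF0000))[0]\n",
        "\n",
        "\n",
        "value = 0.1\n",
        "print('fp16 error:', abs(to_fp16(value) - value))\n",
        "print('bf16 error:', abs(to_bf16_truncated(value) - value))\n",
        "\n",
        "# BF16 keeps float32's exponent range, so a large value survives (coarsely);\n",
        "# the same value would overflow FP16, whose largest finite value is 65504.\n",
        "print(to_bf16_truncated(1.0e6))\n",
        "```\n",
        "\n",
        "FP16 is more precise for small values, while BF16 trades mantissa bits for a wider exponent range."
      ]
    },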
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "slQSqiIxQHVW"
      },
      "source": [
        "### Query vanilla nn.EmbeddingBag with input and offsets\n",
        "\n",
        "We query [`nn.Embedding`](https://pytorch.org/docs/stable/generated/torch.nn.Embedding.html) and [`nn.EmbeddingBag`](https://pytorch.org/docs/stable/generated/torch.nn.EmbeddingBag.html) with `input` and `offsets`. Input is a 1-D tensor containing the lookup values. Offsets is a 1-D tensor where the sequence is a cumulative sum of the number of values to pool per example.\n",
        "\n",
        "Let's look at an example, recreating the product EmbeddingBag above\n",
        "\n",
        "```\n",
        "|------------|\n",
        "| product ID |\n",
        "|------------|\n",
        "| [101, 202] |\n",
        "| []         |\n",
        "| [303]      |\n",
        "|------------|\n",
        "```\n",
        "\n"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 7,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "4T_SExFDBqHS",
        "outputId": "1316b55d-3741-454c-f66b-dfbb000a6060"
      },
      "outputs": [
        {
          "data": {
            "text/plain": [
              "tensor([[ 0.5390, -0.2241,  0.8082, -0.5749,  0.1228, -0.9287, -0.7846, -0.3399,\n",
              "         -0.8951, -0.8286,  1.0466,  0.9958, -0.9805,  0.0507,  0.4695,  0.8182,\n",
              "         -0.5811, -0.8011,  0.0054,  0.6385, -0.6399, -0.3610,  0.4634, -0.7177,\n",
              "          0.0421, -0.5311,  0.6060,  0.8897, -0.9537, -0.9799, -0.0967, -0.6840,\n",
              "          0.9067,  0.4904,  0.1605, -0.1435, -1.0343, -0.9891, -0.1758,  0.6620,\n",
              "          1.0514, -0.7046,  0.0405, -1.0792,  0.7542, -0.1401, -0.3432, -1.5230,\n",
              "          0.0222, -0.2146,  1.1050,  0.3193, -1.2632,  0.6067, -0.3541,  0.7400,\n",
              "         -0.3682,  0.2123,  0.4429,  1.1721,  0.0340,  0.6423, -0.3339, -0.4168],\n",
              "        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
              "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
              "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
              "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
              "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
              "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
              "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
              "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],\n",
              "        [-0.3995,  0.0784,  0.2586, -0.3972,  0.2170, -0.0424,  0.8071,  0.1516,\n",
              "         -0.3463,  1.1072,  1.2706,  0.6777, -1.0217,  1.1251,  2.0801, -0.2139,\n",
              "         -2.5678, -1.4960, -1.8183, -1.0340, -0.1477, -0.0319, -0.9335,  0.5077,\n",
              "          0.3768, -1.1825,  1.1409, -0.8289,  0.8265, -0.5644,  0.1615, -0.1176,\n",
              "         -0.6233,  0.2735, -0.0807,  0.8700, -0.1514, -0.9249,  0.5782, -0.0524,\n",
              "          1.4384,  0.8333,  1.2265,  0.2436,  1.0352,  0.9672, -0.5725,  1.0761,\n",
              "         -0.3872,  0.7274,  1.2550, -1.1618,  0.7739,  0.6375, -0.2184,  0.2534,\n",
              "         -0.2140,  1.3245, -0.8001,  1.0493,  1.5322,  0.3372, -0.3906, -1.1986]],\n",
              "       grad_fn=<EmbeddingBagBackward0>)"
            ]
          },
          "execution_count": 7,
          "metadata": {},
          "output_type": "execute_result"
        }
      ],
      "source": [
        "product_eb = torch.nn.EmbeddingBag(4096, 64)\n",
        "product_eb(input=torch.tensor([101, 202, 303]), offsets=torch.tensor([0, 2, 2]))"
      ]
    },
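    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The pooling performed above can be sketched in plain Python to show how `offsets` splits `input` into bags and why the empty bag yields a row of zeros. The helper `sum_pool` and the tiny table are hypothetical, and the sketch uses sum pooling to match `PoolingType.SUM` in the EBC (note that `nn.EmbeddingBag` itself defaults to mean pooling).\n",
        "\n",
        "```python\n",
        "def sum_pool(table, ids, offsets):\n",
        "    '''Sum the embedding rows of each bag; bag i covers ids[offsets[i]:offsets[i + 1]].'''\n",
        "    dim = len(table[0])\n",
        "    ends = offsets[1:] + [len(ids)]\n",
        "    pooled = []\n",
        "    for start, end in zip(offsets, ends):\n",
        "        row = [0.0] * dim\n",
        "        for idx in ids[start:end]:\n",
        "            row = [a + b for a, b in zip(row, table[idx])]\n",
        "        pooled.append(row)\n",
        "    return pooled\n",
        "\n",
        "\n",
        "# Hypothetical 4-row table with 2-dimensional embeddings.\n",
        "table = [[1.0, 0.0], [0.0, 1.0], [2.0, 2.0], [5.0, 5.0]]\n",
        "\n",
        "# Three bags: [0, 1], [], [2] (the same shape as the product example above).\n",
        "print(sum_pool(table, ids=[0, 1, 2], offsets=[0, 2, 2]))\n",
        "```\n",
        "\n",
        "The second bag is empty because its start offset equals the next bag's start offset, so nothing is added to its zero-initialized row."
      ]
    },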
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "FxFOoBnZCbRX"
      },
      "source": [
        "### Representing minibatches with KeyedJaggedTensor\n",
        "\n",
        "We need an efficient representation of multiple examples of an arbitrary number of entity IDs per feature per example. In order to enable this \"jagged\" representation, we use the TorchRec datastructure [`KeyedJaggedTensor`](https://pytorch.org/torchrec/torchrec.sparse.html#torchrec.sparse.jagged_tensor.JaggedTensor) (KJT).\n",
        "\n",
        "Let's take a look at **how to lookup a collection of two embedding bags**, \"product\" and \"user\".  Assume the minibatch is made up of three examples for three users. The first of which has two product IDs, the second with none, and the third with one product ID.\n",
        "\n",
        "```\n",
        "|------------|------------|\n",
        "| product ID | user ID    |\n",
        "|------------|------------|\n",
        "| [101, 202] | [404]      |\n",
        "| []         | [505]      |\n",
        "| [303]      | [606]      |\n",
        "|------------|------------|\n",
        "```\n",
        "\n",
        "The query should be:"
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 8,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "dKxjPYbpDY3k",
        "outputId": "cb6b3103-e5b7-4e05-edaa-4996a1bae080"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "KeyedJaggedTensor({\n",
            "    \"product\": [[101, 202], [], [303]],\n",
            "    \"user\": [[404], [505], [606]]\n",
            "})\n",
            "\n"
          ]
        }
      ],
      "source": [
        "mb = torchrec.KeyedJaggedTensor(\n",
        "    keys = [\"product\", \"user\"],\n",
        "    values = torch.tensor([101, 202, 303, 404, 505, 606]).cuda(),\n",
        "    lengths = torch.tensor([2, 0, 1, 1, 1, 1], dtype=torch.int64).cuda(),\n",
        ")\n",
        "\n",
        "print(mb.to(torch.device(\"cpu\")))"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "Co1Tb5RQ-J5a"
      },
      "source": [
        "Note that the KJT batch size is `batch_size = len(lengths)//len(keys)`. **In the above example, batch_size is 3.**\n",
        "\n"
      ]
    },
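    {
      "cell_type": "markdown",
      "metadata": {},
      "source": [
        "The flat layout of a KJT can be unpacked by hand to check the batch-size formula. The sketch below is plain Python and does not use TorchRec; `unpack_kjt` is a hypothetical helper that assumes `lengths` is ordered key by key, with one entry per example.\n",
        "\n",
        "```python\n",
        "def unpack_kjt(keys, values, lengths):\n",
        "    '''Rebuild {key: [ids per example]} from flat values and key-major lengths.'''\n",
        "    batch_size = len(lengths) // len(keys)\n",
        "    out, pos = {}, 0\n",
        "    for k, key in enumerate(keys):\n",
        "        rows = []\n",
        "        for length in lengths[k * batch_size:(k + 1) * batch_size]:\n",
        "            rows.append(values[pos:pos + length])\n",
        "            pos += length\n",
        "        out[key] = rows\n",
        "    return batch_size, out\n",
        "\n",
        "\n",
        "batch_size, features = unpack_kjt(\n",
        "    keys=['product', 'user'],\n",
        "    values=[101, 202, 303, 404, 505, 606],\n",
        "    lengths=[2, 0, 1, 1, 1, 1],\n",
        ")\n",
        "print(batch_size)\n",
        "print(features)\n",
        "```\n",
        "\n",
        "The first three lengths belong to `product` and the last three to `user`, which reproduces the table shown earlier."
      ]
    },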
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ZjP4ctxqnmsU"
      },
      "source": [
        "### Putting it all together, querying our distributed model with a KJT minibatch\n",
        "Finally, we can query our model using our minibatch of products and users.\n",
        "\n",
        "The resulting lookup will contain a KeyedTensor, where each key (or feature) contains a 2D tensor of size 3x64 (batch_size x embedding_dim)."
      ]
    },
    {
      "cell_type": "code",
      "execution_count": 17,
      "metadata": {
        "colab": {
          "base_uri": "https://localhost:8080/"
        },
        "id": "TmnV3iH4IXn8",
        "outputId": "2dc38da9-0e30-4e76-dc80-47c9d80074c7"
      },
      "outputs": [
        {
          "name": "stdout",
          "output_type": "stream",
          "text": [
            "product embeddings tensor([[-0.0051,  0.0119, -0.0266,  0.0086, -0.0233, -0.0021,  0.0144,  0.0168,\n",
            "          0.0038, -0.0053, -0.0216, -0.0173, -0.0157,  0.0108, -0.0210, -0.0072,\n",
            "         -0.0108,  0.0071, -0.0076,  0.0295,  0.0257,  0.0107, -0.0049, -0.0035,\n",
            "         -0.0010, -0.0122, -0.0241,  0.0012,  0.0107,  0.0062,  0.0219,  0.0213,\n",
            "          0.0179,  0.0073, -0.0009,  0.0101,  0.0248, -0.0185,  0.0136,  0.0132,\n",
            "         -0.0290,  0.0144,  0.0048,  0.0006,  0.0079, -0.0225, -0.0007,  0.0117,\n",
            "          0.0105,  0.0014, -0.0198, -0.0026,  0.0194,  0.0116,  0.0114, -0.0049,\n",
            "          0.0173,  0.0022,  0.0141, -0.0245,  0.0044,  0.0022, -0.0061, -0.0049],\n",
            "        [ 0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
            "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
            "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
            "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
            "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
            "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
            "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,\n",
            "          0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000,  0.0000],\n",
            "        [-0.0042, -0.0089,  0.0031,  0.0150,  0.0030, -0.0103,  0.0143,  0.0121,\n",
            "          0.0092, -0.0038, -0.0022, -0.0063,  0.0128, -0.0022, -0.0131, -0.0019,\n",
            "         -0.0006,  0.0053,  0.0008,  0.0021,  0.0012, -0.0012,  0.0025,  0.0116,\n",
            "         -0.0155,  0.0051, -0.0032, -0.0041, -0.0105, -0.0014, -0.0024,  0.0052,\n",
            "          0.0061,  0.0004,  0.0124,  0.0124,  0.0136, -0.0034, -0.0075, -0.0059,\n",
            "         -0.0054, -0.0004, -0.0120, -0.0105, -0.0126, -0.0129,  0.0120,  0.0014,\n",
            "         -0.0037,  0.0026,  0.0145, -0.0053,  0.0095,  0.0136, -0.0058,  0.0030,\n",
            "          0.0023, -0.0032,  0.0037, -0.0074, -0.0030,  0.0027, -0.0046, -0.0123]],\n",
            "       device='cuda:0', grad_fn=<SplitWithSizesBackward0>)\n",
            "user embeddings tensor([[ 1.7781e-03, -1.4074e-02, -1.0030e-02, -8.7433e-03, -6.4829e-03,\n",
            "         -1.1644e-02,  3.1367e-03,  1.4910e-02,  1.3202e-02, -3.8754e-03,\n",
            "         -1.4966e-02,  2.1007e-03, -7.4155e-03, -4.5403e-03, -6.7812e-03,\n",
            "          6.3994e-03,  1.2086e-02,  1.2878e-02,  1.4239e-02, -1.2684e-02,\n",
            "         -1.1834e-02,  1.3470e-03, -1.0714e-02, -4.4043e-04,  1.1827e-02,\n",
            "         -1.5078e-02, -1.1874e-02, -1.2019e-02,  1.1807e-02, -7.4412e-03,\n",
            "         -1.0427e-03, -1.1670e-02,  1.2619e-02, -5.7154e-03, -4.1964e-04,\n",
            "          7.4416e-03, -4.8433e-03, -2.4990e-03, -4.5710e-03,  1.6587e-03,\n",
            "          9.4574e-03, -1.3659e-02,  1.4636e-02, -1.0672e-03,  7.5791e-03,\n",
            "         -6.5686e-03,  2.4470e-03, -1.3773e-02,  1.3362e-02, -1.2804e-02,\n",
            "         -2.6700e-03,  8.1768e-03, -8.8283e-03, -5.0410e-03, -4.3076e-03,\n",
            "          1.3206e-02, -1.3843e-02,  4.1662e-03, -1.0851e-02, -3.7601e-03,\n",
            "         -2.9842e-03, -1.2234e-02, -1.3067e-02,  2.4397e-05],\n",
            "        [-1.1139e-02, -1.1935e-02, -9.5498e-03, -9.3701e-03,  1.5218e-02,\n",
            "         -1.3971e-02,  1.0204e-02, -1.3939e-02, -9.6123e-03,  3.5847e-03,\n",
            "          2.3046e-03,  8.8443e-03, -1.8695e-03,  1.2977e-02, -1.2108e-02,\n",
            "          6.0886e-03, -1.9408e-03, -3.1992e-03,  8.6488e-03,  1.4559e-02,\n",
            "          2.4447e-03, -1.1202e-02,  6.1081e-03, -2.5672e-03, -1.3761e-02,\n",
            "          8.5980e-03, -1.2189e-02,  1.1540e-02,  8.1158e-03,  1.0525e-02,\n",
            "         -1.4594e-02, -8.7755e-03,  6.4152e-03, -3.6499e-03, -1.0631e-02,\n",
            "          2.5016e-03, -7.7545e-03, -1.0751e-03,  1.5677e-03, -4.2500e-03,\n",
            "         -2.9060e-03, -1.4003e-02,  4.7707e-04, -5.4636e-03, -2.1258e-03,\n",
            "         -1.4120e-02, -8.9897e-04, -1.3123e-02, -6.0624e-03,  1.3688e-02,\n",
            "          2.1036e-03, -2.0213e-03, -1.0417e-02, -3.4156e-03,  1.2653e-02,\n",
            "          1.5182e-02,  1.2269e-02,  3.2597e-03, -1.0406e-02,  5.0531e-04,\n",
            "         -6.5279e-03,  9.1205e-03,  1.4822e-02,  2.2662e-03],\n",
            "        [-5.7377e-03, -8.1668e-03, -6.3326e-03,  1.0612e-02, -2.1545e-04,\n",
            "          1.6740e-03,  1.3960e-02, -1.4884e-02,  1.9036e-03, -6.9326e-03,\n",
            "          4.8978e-03, -1.4179e-02, -7.6593e-04,  1.4389e-02, -2.8147e-03,\n",
            "          3.3961e-03, -1.3536e-02, -5.9782e-03,  5.8542e-03,  2.0505e-03,\n",
            "          1.4506e-02,  9.5605e-03, -6.7243e-03, -3.5396e-03,  9.8425e-03,\n",
            "          6.2296e-03,  5.9379e-03, -2.6117e-03,  9.5100e-03, -5.9586e-05,\n",
            "         -2.5985e-03, -2.4412e-03,  1.0931e-02,  1.2986e-02, -1.0906e-02,\n",
            "          3.3439e-03, -2.4703e-04,  1.0547e-02,  1.0157e-02,  5.2837e-03,\n",
            "         -6.3240e-03, -1.0825e-02, -4.3349e-03,  5.3597e-03, -1.2550e-02,\n",
            "         -1.3413e-02, -6.4192e-03,  2.6606e-03,  7.5830e-03,  3.1177e-04,\n",
            "          1.2622e-02, -7.0833e-04,  1.3353e-02, -1.2675e-02, -1.2349e-03,\n",
            "          9.4488e-03, -8.5425e-03, -2.6770e-03,  8.0133e-03, -1.3499e-02,\n",
            "         -8.6746e-03,  1.2764e-03,  1.5327e-02,  5.8484e-04]], device='cuda:0',\n",
            "       grad_fn=<SplitWithSizesBackward0>)\n"
          ]
        }
      ],
      "source": [
        "pooled_embeddings = model(mb).to_dict()\n",
        "print(\"product embeddings\", pooled_embeddings[\"product\"])\n",
        "print(\"user embeddings\", pooled_embeddings[\"user\"])"
      ]
    },
    {
      "cell_type": "markdown",
      "metadata": {
        "id": "ebXfh7oW9fHH"
      },
      "source": [
        "## More resources\n",
        "For more information, please see our [dlrm](https://github.com/facebookresearch/dlrm/tree/main/torchrec_dlrm/) example, which includes multinode training on the criteo terabyte dataset, using Meta’s [DLRM](https://arxiv.org/abs/1906.00091)."
      ]
    }
  ],
  "metadata": {
    "custom": {
      "cells": [],
      "metadata": {
        "accelerator": "GPU",
        "colab": {
          "background_execution": "on",
          "collapsed_sections": [],
          "machine_shape": "hm",
          "name": "Torchrec Introduction.ipynb",
          "provenance": []
        },
        "fileHeader": "",
        "fileUid": "c9a29462-2509-4adb-a539-0318cf56bb00",
        "interpreter": {
          "hash": "d4204deb07d30e7517ec64733b2d65f24aff851b061e21418071854b06459363"
        },
        "isAdHoc": false,
        "kernelspec": {
          "display_name": "Python 3.7.13 ('torchrec': conda)",
          "language": "python",
          "name": "python3"
        },
        "language_info": {
          "codemirror_mode": {
            "name": "ipython",
            "version": 3
          },
          "file_extension": ".py",
          "mimetype": "text/x-python",
          "name": "python",
          "nbconvert_exporter": "python",
          "pygments_lexer": "ipython3",
          "version": "3.7.13"
        }
      },
      "nbformat": 4,
      "nbformat_minor": 0
    },
    "indentAmount": 2
  },
  "nbformat": 4,
  "nbformat_minor": 2
}
